//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cuda/__driver/driver_api.h>
#include <cuda/__runtime/ensure_current_context.h>
#include <cuda/devices>
#include <cuda/std/cstddef>
#include <cuda/std/optional>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

#include <cuda/experimental/kernel.cuh>

#include <testing.cuh>

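// Reference CUDA source for the PTX embedded in kernel_ptx_src below (kept as a comment, not compiled here).
// Note: `threadIdx.x + 1 % 32` parses as `threadIdx.x + (1 % 32)`, i.e. `threadIdx.x + 1`, which is what the
// `+4` byte offset of the shared-memory loads in the PTX reflects.
//
// extern "C" __constant__ int const_data;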
//
// extern "C" __global__ void kernel_ptx1(int* array, int n)
// {
//   __shared__ int shared[32];
//   int tid = blockDim.x * blockIdx.x + threadIdx.x;
//   if (tid < n)
//   {
//     shared[threadIdx.x] = array[tid];
//     __syncthreads();
//     array[tid] = shared[threadIdx.x + 1 % 32] + const_data;
//   }
// }
//
// extern "C" __global__ void kernel_ptx2(float* array, int n)
// {
//   __shared__ float shared[32];
//   int tid = blockDim.x * blockIdx.x + threadIdx.x;
//   if (tid < n)
//   {
//     shared[threadIdx.x] = array[tid];
//     __syncthreads();
//     array[tid] = shared[threadIdx.x + 1 % 32] + static_cast<float>(const_data);
//   }
// }

// PTX module used as the fixture for the kernel_ref test. It declares a 4-byte constant `const_data` and two
// entry points, `kernel_ptx1` and `kernel_ptx2`, each taking a 64-bit pointer and a 32-bit integer parameter and
// each owning a 128-byte static shared array. The module is PTX ISA 8.0 targeting sm_75. The text is handed to
// the driver verbatim at run time, so the `//` lines inside the raw string are PTX comments and part of the data.
constexpr char kernel_ptx_src[] = R"(
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32267302
// Cuda compilation tools, release 12.0, V12.0.140
// Based on NVVM 7.0.1
//

.version 8.0
.target sm_75
.address_size 64

	// .globl	kernel_ptx1
.const .align 4 .u32 const_data;
// _ZZ11kernel_ptx1E6shared has been demoted
// _ZZ11kernel_ptx2E6shared has been demoted

.visible .entry kernel_ptx1(
	.param .u64 kernel_ptx1_param_0,
	.param .u32 kernel_ptx1_param_1
)
{
	.reg .pred 	%p<2>;
	.reg .b32 	%r<13>;
	.reg .b64 	%rd<5>;
	// demoted variable
	.shared .align 4 .b8 _ZZ11kernel_ptx1E6shared[128];

	ld.param.u64 	%rd1, [kernel_ptx1_param_0];
	ld.param.u32 	%r3, [kernel_ptx1_param_1];
	mov.u32 	%r4, %ntid.x;
	mov.u32 	%r5, %ctaid.x;
	mov.u32 	%r1, %tid.x;
	mad.lo.s32 	%r2, %r4, %r5, %r1;
	setp.ge.s32 	%p1, %r2, %r3;
	@%p1 bra 	$L__BB0_2;

	cvta.to.global.u64 	%rd2, %rd1;
	mul.wide.s32 	%rd3, %r2, 4;
	add.s64 	%rd4, %rd2, %rd3;
	ld.global.u32 	%r6, [%rd4];
	shl.b32 	%r7, %r1, 2;
	mov.u32 	%r8, _ZZ11kernel_ptx1E6shared;
	add.s32 	%r9, %r8, %r7;
	st.shared.u32 	[%r9], %r6;
	bar.sync 	0;
	ld.const.u32 	%r10, [const_data];
	ld.shared.u32 	%r11, [%r9+4];
	add.s32 	%r12, %r10, %r11;
	st.global.u32 	[%rd4], %r12;

$L__BB0_2:
	ret;

}
	// .globl	kernel_ptx2
.visible .entry kernel_ptx2(
	.param .u64 kernel_ptx2_param_0,
	.param .u32 kernel_ptx2_param_1
)
{
	.reg .pred 	%p<2>;
	.reg .f32 	%f<5>;
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<5>;
	// demoted variable
	.shared .align 4 .b8 _ZZ11kernel_ptx2E6shared[128];

	ld.param.u64 	%rd1, [kernel_ptx2_param_0];
	ld.param.u32 	%r3, [kernel_ptx2_param_1];
	mov.u32 	%r4, %ntid.x;
	mov.u32 	%r5, %ctaid.x;
	mov.u32 	%r1, %tid.x;
	mad.lo.s32 	%r2, %r4, %r5, %r1;
	setp.ge.s32 	%p1, %r2, %r3;
	@%p1 bra 	$L__BB1_2;

	cvta.to.global.u64 	%rd2, %rd1;
	mul.wide.s32 	%rd3, %r2, 4;
	add.s64 	%rd4, %rd2, %rd3;
	ld.global.f32 	%f1, [%rd4];
	shl.b32 	%r6, %r1, 2;
	mov.u32 	%r7, _ZZ11kernel_ptx2E6shared;
	add.s32 	%r8, %r7, %r6;
	st.shared.f32 	[%r8], %f1;
	bar.sync 	0;
	ld.const.u32 	%r9, [const_data];
	cvt.rn.f32.s32 	%f2, %r9;
	ld.shared.f32 	%f3, [%r8+4];
	add.f32 	%f4, %f3, %f2;
	st.global.f32 	[%rd4], %f4;

$L__BB1_2:
	ret;

}
)";

#if _CCCL_CTK_AT_LEAST(12, 1)
// Kernel compiled with this translation unit; the test uses it to build a kernel_ref from a __global__ function.
//
// Each thread computes its flat 1D index and increments data[idx] when idx is below size.
//
// data - array the kernel writes to; must hold at least `size` elements
// size - number of valid elements; a zero or negative value makes the kernel a no-op
//
// The index is unsigned, so comparing it directly against a negative `size` would convert `size` to a huge
// unsigned value and let every thread pass the bounds check. Non-positive sizes are therefore rejected first
// and the remaining comparison is done explicitly in unsigned arithmetic.
extern "C" __global__ void kernel_rt(int* data, int size)
{
  if (size <= 0)
  {
    return;
  }

  const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < static_cast<unsigned int>(size))
  {
    data[idx] += 1;
  }
}
#endif // _CCCL_CTK_AT_LEAST(12, 1)

// Verifies a single kernel attribute through both of its query paths and hands back the queried value.
//
// Template parameters:
//   Attr           - attribute object under test
//   ExpectedAttr   - driver enumerator that Attr has to compare equal to
//   ExpectedResult - exact type that kernel.attribute(Attr, dev) has to yield
//   Signature      - kernel signature, deduced from the `kernel` argument
//
// Parameters:
//   kernel         - kernel whose attribute is queried
//   dev            - device the attribute is queried for
//   expected_value - when engaged, the value the query must produce; when empty, only the two query paths are
//                    compared against each other
//
// Returns the value obtained from kernel.attribute(Attr, dev).
template <const auto& Attr, ::CUfunction_attribute ExpectedAttr, class ExpectedResult, class Signature>
[[maybe_unused]] auto test_kernel_attribute(
  cudax::kernel_ref<Signature> kernel,
  cuda::device_ref dev,
  cuda::std::optional<ExpectedResult> expected_value = cuda::std::nullopt)
{
  // The attribute object must map onto the expected driver enumerator.
  STATIC_REQUIRE(Attr == ExpectedAttr);

  // Path 1: member query on the kernel reference. Kept non-const so decltype yields the plain result type.
  auto queried = kernel.attribute(Attr, dev);
  STATIC_REQUIRE(cuda::std::is_same_v<decltype(queried), ExpectedResult>);

  if (expected_value)
  {
    CUDAX_REQUIRE(queried == *expected_value);
  }

  // Path 2: invoking the attribute object directly must agree with the member query.
  const auto via_attribute_call = Attr(kernel, dev);
  CUDAX_REQUIRE(queried == via_attribute_call);

  return queried;
}

// Exercises cudax::kernel_ref: member types, construction (from a CUkernel handle and, on CTK 12.1+, from a
// __global__ function), copying, name lookup (CTK 12.3+), attribute queries, get(), comparison and class template
// argument deduction. The PTX kernels come from kernel_ptx_src, which is loaded as a driver library at the start
// of the test and unloaded at the end.
C2H_CCCLRT_TEST("Kernel reference", "[kernel_ref]")
{
  CUlibrary lib = _CUDA_DRIVER::__libraryLoadData(kernel_ptx_src, nullptr, nullptr, 0, nullptr, nullptr, 0);

  CUkernel kernel_ptx1_handle = _CUDA_DRIVER::__libraryGetKernel(lib, "kernel_ptx1");
  CUkernel kernel_ptx2_handle = _CUDA_DRIVER::__libraryGetKernel(lib, "kernel_ptx2");

  cuda::device_ref device{0};
  cuda::__ensure_current_context context_guard{device};

  // Types
  {
    STATIC_REQUIRE(cuda::std::is_same_v<typename cudax::kernel_ref<void()>::value_type, CUkernel>);
  }

  // Default constructor
  {
    STATIC_REQUIRE(!cuda::std::is_default_constructible_v<cudax::kernel_ref<void()>>);
  }

  // Constructor from kernel handle
  {
    // Construction from a raw handle must be explicit.
    STATIC_REQUIRE(!cuda::std::is_convertible_v<CUkernel, cudax::kernel_ref<void()>>);
    STATIC_REQUIRE(cuda::std::is_constructible_v<cudax::kernel_ref<void()>, CUkernel>);

    // We currently have no way to check if the kernel parameters match
    {
      cudax::kernel_ref<void()> kernel_ref{kernel_ptx1_handle};
      CUDAX_REQUIRE(kernel_ptx1_handle == kernel_ref.get());

      cudax::kernel_ref<void()> kernel_ref2{kernel_ptx2_handle};
      CUDAX_REQUIRE(kernel_ptx2_handle == kernel_ref2.get());

      CUDAX_REQUIRE(kernel_ptx1_handle != kernel_ptx2_handle);
      CUDAX_REQUIRE(kernel_ref.get() != kernel_ref2.get());
    }
    {
      cudax::kernel_ref<void(int*, int)> kernel_ref{kernel_ptx1_handle};
      CUDAX_REQUIRE(kernel_ptx1_handle == kernel_ref.get());

      cudax::kernel_ref<void(int*, int)> kernel_ref2{kernel_ptx2_handle};
      CUDAX_REQUIRE(kernel_ptx2_handle == kernel_ref2.get());

      CUDAX_REQUIRE(kernel_ptx1_handle != kernel_ptx2_handle);
      CUDAX_REQUIRE(kernel_ref.get() != kernel_ref2.get());
    }
  }

  // Constructor from kernel function
#if _CCCL_CTK_AT_LEAST(12, 1)
  {
    // A __global__ function converts implicitly, but only to a kernel_ref with the matching signature.
    STATIC_REQUIRE(cuda::std::is_constructible_v<cudax::kernel_ref<void(int*, int)>, decltype(kernel_rt)>);
    STATIC_REQUIRE(cuda::std::is_convertible_v<decltype(kernel_rt), cudax::kernel_ref<void(int*, int)>>);
    STATIC_REQUIRE(!cuda::std::is_constructible_v<cudax::kernel_ref<void()>, decltype(kernel_rt)>);

    CUkernel kernel_rt_handle{};
    CUDAX_REQUIRE(cudaGetKernel(&kernel_rt_handle, kernel_rt) == cudaSuccess);

    cudax::kernel_ref<void(int*, int)> kernel_ref1{kernel_rt};
    CUDAX_REQUIRE(kernel_rt_handle == kernel_ref1.get());
  }
#endif // _CCCL_CTK_AT_LEAST(12, 1)

  // Copy constructor
  {
    STATIC_REQUIRE(cuda::std::is_trivially_copy_constructible_v<cudax::kernel_ref<void()>>);

    cudax::kernel_ref<void(int*, int)> kernel_ref1{kernel_ptx1_handle};
    CUDAX_REQUIRE(kernel_ptx1_handle == kernel_ref1.get());

    cudax::kernel_ref<void(int*, int)> kernel_ref2{kernel_ref1};
    CUDAX_REQUIRE(kernel_ptx1_handle == kernel_ref2.get());
    CUDAX_REQUIRE(kernel_ref1.get() == kernel_ref2.get());
  }

  // Name
#if _CCCL_CTK_AT_LEAST(12, 3)
  {
    STATIC_REQUIRE(
      cuda::std::is_same_v<decltype(cuda::std::declval<cudax::kernel_ref<void()>>().name()), cuda::std::string_view>);

    cudax::kernel_ref<void(int*, int)> kernel_ref{kernel_ptx1_handle};
    CUDAX_REQUIRE(kernel_ref.name() == "kernel_ptx1");
  }
#endif // _CCCL_CTK_AT_LEAST(12, 3)

  // Attributes
  // The expected values mirror the PTX of kernel_ptx1: a 32-int static shared array, one int of constant memory
  // and `.target sm_75`. Attributes without an expected value are only checked for type and self-consistency.
  {
    cudax::kernel_ref<void(int*, int)> kernel_ref{kernel_ptx1_handle};

    const auto cc = cuda::device_attributes::compute_capability(device);

    test_kernel_attribute<cudax::kernel_attributes::max_threads_per_block, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, int>(
      kernel_ref, device);
    test_kernel_attribute<cudax::kernel_attributes::shared_memory_size,
                          CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
                          cuda::std::size_t>(kernel_ref, device, sizeof(int) * 32);
    test_kernel_attribute<cudax::kernel_attributes::const_memory_size,
                          CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES,
                          cuda::std::size_t>(kernel_ref, device, sizeof(int));
    test_kernel_attribute<cudax::kernel_attributes::local_memory_size,
                          CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES,
                          cuda::std::size_t>(kernel_ref, device);
    test_kernel_attribute<cudax::kernel_attributes::num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, int>(kernel_ref, device);
    test_kernel_attribute<cudax::kernel_attributes::virtual_arch, CU_FUNC_ATTRIBUTE_PTX_VERSION, cuda::arch_id>(
      kernel_ref, device, cuda::arch_id::sm_75);
    test_kernel_attribute<cudax::kernel_attributes::binary_arch, CU_FUNC_ATTRIBUTE_BINARY_VERSION, cuda::arch_id>(
      kernel_ref, device, cuda::to_arch_id(cc));
    test_kernel_attribute<cudax::kernel_attributes::cache_mode_ca, CU_FUNC_ATTRIBUTE_CACHE_MODE_CA, bool>(
      kernel_ref, device, false);
    test_kernel_attribute<cudax::kernel_attributes::requires_cluster_dims,
                          CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET,
                          bool>(kernel_ref, device, false);
  }

  // Get handle
  {
    STATIC_REQUIRE(cuda::std::is_same_v<decltype(cuda::std::declval<cudax::kernel_ref<void()>>().get()), CUkernel>);

    cudax::kernel_ref<void(int*, int)> kernel_ref{kernel_ptx1_handle};
    CUDAX_REQUIRE(kernel_ptx1_handle == kernel_ref.get());
  }

  // Equality/Inequality comparison
  {
    cudax::kernel_ref<void(int*, int)> kernel_ref1{kernel_ptx1_handle};
    cudax::kernel_ref<void(int*, int)> kernel_ref2{kernel_ptx2_handle};

    CUDAX_REQUIRE(kernel_ref1 == kernel_ref1);
    CUDAX_REQUIRE(kernel_ref1 != kernel_ref2);
  }

  // Deduction guides
#if _CCCL_CTK_AT_LEAST(12, 1)
  {
    cudax::kernel_ref kernel_ref1{kernel_rt};
    // The deduced type is a compile-time property, so check it statically like the other type checks above.
    STATIC_REQUIRE(cuda::std::is_same_v<decltype(kernel_ref1), cudax::kernel_ref<void(int*, int)>>);
  }
#endif // _CCCL_CTK_AT_LEAST(12, 1)

  // Release the library that provided the PTX kernels.
  CUDAX_REQUIRE(_CUDA_DRIVER::__libraryUnloadNoThrow(lib) == cudaSuccess);
}
